#!/usr/bin/env python

"""
Convert 8 column CSV files containing RADIUS or Diameter AVP tables
into various formats.

Each row of a CSV file is one of:
- Row per 3GPP AVP tables:
    Name, Code, Section, DataType, Must, May, ShouldNot, MustNot
    - Name:
        AVP Name. String, validated as ALPHA *(ALPHA / DIGIT / "-")
        per RFC 6733 section 3.2, with two relaxations:
        may start with a DIGIT (e.g., "3GPP-IMSI") and may contain ".".
    - Code:
        AVP Code. Integer, 0..4294967295.
    - Section:
        Section in relevant standard. String.
    - DataType:
        AVP Data Type. String, validated per basic and derived types in:
            - RFC 6733 section 4.2
            - RFC 6733 section 4.3
            - RFC 7155 section 4.1
    - Must, May, ShouldNot, MustNot:
        Flags, possibly comma or space separated: M, P, V

- Comment row. First cell:
    # Comment text      'Comment text'
    #=                  '/*========*/'
    #                   Blank line

- Parameter row:
    @Parameter,Value [, ...]
  Supported Parameter terms:
    standard    Standard name. E.g. '3GPP TS 29.272', 'RFC 6733'.
    vendor      Vendor number.

"""

from __future__ import print_function
from __future__ import with_statement

import abc
import csv
import collections
import json
import re
import optparse
import os
import sys

# Column names of an AVP row, in CSV column order.
# Kept as a list because it is concatenated with other lists.
CSV_COLUMN_NAMES = [
    'name', 'code', 'section', 'datatype',
    'must', 'may', 'shouldnot', 'mustnot',
]

# Map each derived AVP data type to the basic type it is built on.
DERIVED_TO_BASE = {
    # RFC 6733 section 4.3.1
    'Address': 'OctetString',
    'Time': 'OctetString',
    'UTF8String': 'OctetString',
    'DiameterIdentity': 'OctetString',
    'DiameterURI': 'OctetString',
    'Enumerated': 'Integer32',
    'IPFilterRule': 'OctetString',
    # RFC 7155 section 4.1.1
    'QoSFilterRule': 'OctetString',
}

# Vendor number to vendor name.
# See https://www.iana.org/assignments/enterprise-numbers/enterprise-numbers
VENDOR_TO_NAME = {
    0: '',
    193: 'Ericsson',
    8164: 'Starent',
    10415: '3GPP',
}


class Avp(object):
    """Store and validate one AVP row.

    The first eight attributes mirror the CSV columns (see
    CSV_COLUMN_NAMES); ``filename``, ``line_num``, ``standard`` and
    ``vendor`` record the file state supplied by the caller.

    The constructor raises ValueError if the name, code, data type or
    flags are invalid, or if the "V" flag disagrees with the vendor.
    """

    # Regex to validate avp-name per RFC 6733 section 3.2,
    # with changes:
    # - Allow avp-name to start with numbers (for 3GPP)
    # - Allow '.' in avp-name, for existing dict_dcca_3gpp usage.
    # TODO: if starts with digit, ensure contains a letter somewhere?
    _name_re = re.compile(r'^[a-zA-Z0-9][a-zA-Z0-9-\.]*$')

    # Regex to validate flags: M, P, V, comma, space
    _flags_re = re.compile(r'^[MPV, ]*$')

    # Basic data types accepted as-is; derived types are looked up in
    # DERIVED_TO_BASE.
    _base_types = (
        'OctetString', 'Integer32', 'Integer64', 'Unsigned32',
        'Unsigned64', 'Float32', 'Float64', 'Grouped')

    __slots__ = CSV_COLUMN_NAMES + [
        'filename', 'line_num', 'standard', 'vendor', ]

    def __init__(self, name, code, section, datatype,
                 must, may, shouldnot, mustnot, extra_cells=None,
                 filename='', line_num=0, standard='', vendor=0):
        """Build an Avp from the cells of a CSV row and validate it.

        Args:
            name: AVP name.
            code: AVP code; anything int() accepts, in 0..4294967295.
            section: Section in the relevant standard.
            datatype: Basic or derived AVP data type name.
            must, may, shouldnot, mustnot: Flag strings made of
                M, P, V, commas and spaces.
            extra_cells: Accepted for call compatibility; not used.
            filename: Source file of the row.
            line_num: Line number of the row in ``filename``.
            standard: Standard name in effect for the row.
            vendor: Vendor number in effect for the row.

        Raises:
            ValueError: if any field fails validation.
        """
        # Members from CSV row
        self.name = name
        try:
            self.code = int(code)
        except (TypeError, ValueError):
            # Report which AVP carries the bad code instead of the bare
            # int() conversion message.
            raise ValueError('AVP "{}" invalid code "{}"'.format(name, code))
        self.section = section
        self.datatype = datatype
        self.must = must
        self.may = may
        self.shouldnot = shouldnot
        self.mustnot = mustnot
        # Members from file state
        self.filename = filename
        self.line_num = line_num
        self.standard = standard
        self.vendor = vendor
        # Validate CSV fields
        if not self._name_re.match(self.name):
            raise ValueError('Invalid AVP name "{}"'.format(self.name))
        if self.code < 0 or self.code > 4294967295:
            raise ValueError('AVP "{}" invalid code {}'.format(
                self.name, self.code))
        if (self.datatype not in self._base_types
                and self.datatype not in DERIVED_TO_BASE):
            raise ValueError('{} invalid data type "{}"'.format(
                self.description(), self.datatype))
        # Validate flags, counting every character seen across the
        # four flag columns.
        flags = collections.Counter()
        for val, desc in [
                (self.must, 'Must'),
                (self.may, 'May'),
                (self.shouldnot, 'Should Not'),
                (self.mustnot, 'Must Not'),
                ]:
            if not self._flags_re.match(val):
                raise ValueError('{} invalid {} Flags "{}"'.format(
                    self.description(), desc, val))
            flags.update(val)
        # Check occurrence of M,V in Must,May,ShouldNot,MustNot
        for flag in 'MV':
            # TODO: can AVP flags not appear at all?
            # if flags[flag] == 0:
            #     raise ValueError('{} Flag "{}" not set'.format(
            #         self.description(), flag))
            if flags[flag] > 1:
                raise ValueError('{} Flag "{}" set {} times'.format(
                    self.description(), flag, flags[flag]))
        # Compare V presence against vendor
        if 'V' in self.must:
            if self.vendor == 0:
                raise ValueError('{} Flag "V" set for vendor 0'.format(
                    self.description()))
        else:
            if self.vendor != 0:
                raise ValueError('{} Flag "V" not set for vendor {}'.format(
                    self.description(), self.vendor))

    @property
    def __dict__(self):
        """Return the slot values as a dict, so that vars() works."""
        return {s: getattr(self, s) for s in self.__slots__}

    def __eq__(self, other):
        """Equality comparison of Avp instances.
        Considered equal if name, vendor, code, datatype, and flags are equal.
        """
        if other is self:
            return True
        if type(other) is not type(self):
            return NotImplemented
        return (
            other.name, other.vendor, other.code, other.datatype,
            other.must, other.may, other.shouldnot, other.mustnot,
            ) == (
            self.name, self.vendor, self.code, self.datatype,
            self.must, self.may, self.shouldnot, self.mustnot,
            )

    def __ne__(self, other):
        """Return the negation of the equality comparison."""
        return not self == other

    def description(self):
        """Return 'AVP "<name>" (<code>)' for use in messages."""
        return 'AVP "{}" ({})'.format(self.name, self.code)


# The base is created by calling ABCMeta directly: a class-body
# "__metaclass__" attribute is only honoured by Python 2, so under
# Python 3 the abstract methods would otherwise not be enforced.
class Processor(abc.ABCMeta('_ProcessorBase', (object,), {})):
    """Interface for processor of Avp.

    Concrete processors subclass this directly and implement every
    abstract method below.
    """

    @classmethod
    def cls_name(cls):
        """Return the name, lower-case, without "processor" suffix."""
        suffix = 'processor'
        name = cls.__name__.lower()
        if name.endswith(suffix):
            return name[:-len(suffix)]
        return name

    @classmethod
    def cls_desc(cls):
        """Return the first line of the docstring.

        Returns an empty string when the class has no docstring.
        """
        if cls.__doc__ is None:
            return ""
        return cls.__doc__.split('\n')[0]

    @abc.abstractmethod
    def filename(self, filename):
        """Called when a file is opened."""
        pass

    @abc.abstractmethod
    def avp(self, avp):
        """Process a validated Avp."""
        pass

    @abc.abstractmethod
    def comment(self, comment, filename, line_num):
        """Process a comment row:
            #comment,
        """
        pass

    @abc.abstractmethod
    def generate(self):
        """Invoked after all rows processed."""
        pass

    @abc.abstractmethod
    def parameter(self, name, value):
        """Process a parameter row:
            @name,value.
        """
        pass


class DebugProcessor(Processor):
    """Display the CSV parsing.

    Every callback prints one line describing what it received.
    """

    def filename(self, filename):
        """Print the name of the file being read."""
        message = 'File: {}'.format(filename)
        print(message)

    def avp(self, avp):
        """Print the name, code and data type of the AVP."""
        message = 'AVP: {name}, {code}, {datatype}'.format(**vars(avp))
        print(message)

    def comment(self, comment, filename, line_num):
        """Print the comment text; the location is not shown."""
        message = 'Comment: {}'.format(comment)
        print(message)

    def generate(self):
        """Print a marker for the generate step."""
        print('Generate')

    def parameter(self, name, value):
        """Print the parameter name and value."""
        message = 'Parameter: {} {}'.format(name, value)
        print(message)


class NoopProcessor(Processor):
    """Validate the CSV; no other output.

    Every callback is deliberately empty.
    """

    def filename(self, filename):
        """Ignore the file name."""

    def avp(self, avp):
        """Ignore the AVP."""

    def comment(self, comment, filename, line_num):
        """Ignore the comment row."""

    def generate(self):
        """Produce nothing."""

    def parameter(self, name, value):
        """Ignore the parameter row."""


class FdcProcessor(Processor):
    """Generate freeDiameter C code.

    Comment cells are parsed as:
        # text comment  /* text comment */
        #=              /*==============*/
        #               [blank line]

    Lines are accumulated per generated C function and written to
    standard output by generate().
    """

    # Width that comment text is padded to inside /* ... */.
    COMMENT_WIDTH = 64

    class AvpFunction(object):
        """Maintain per-function state to create DICT_AVP entries.
        """

        def __init__(self, name):
            self.__name = name
            self.__lines = []
            self.__derived = set()

        @property
        def name(self):
            """Return name."""
            return self.__name

        @property
        def lines(self):
            """Return all lines."""
            return self.__lines

        @lines.setter
        def lines(self, value):
            """Set to append a line."""
            self.__lines.append(value)

        @property
        def derived(self):
            """Return list of all derived values, sorted.

            Sorted so that the generated output does not depend on set
            iteration order.
            """
            return sorted(self.__derived)

        @derived.setter
        def derived(self, value):
            """Set to store a derived type."""
            self.__derived.add(value)

    def __init__(self):
        self._filenames = []
        self._functions = collections.OrderedDict()

    def filename(self, filename):
        """Record the base name of the file for the generated header."""
        self._filenames.append(os.path.basename(filename))

    def avp(self, avp):
        """Append the C block that creates the DICT_AVP entry for avp.

        Raises:
            ValueError: if an Enumerated AVP has a non-zero vendor that
                is missing from VENDOR_TO_NAME.
        """
        # Resolve the Enumerated type-name prefix before emitting
        # anything, so an unknown vendor is reported as a ValueError
        # (instead of an uncaught KeyError) with no partial entry added.
        vendor_prefix = ''
        if avp.datatype == 'Enumerated' and avp.vendor != 0:
            try:
                vendor_prefix = '{}/'.format(VENDOR_TO_NAME[avp.vendor])
            except KeyError:
                raise ValueError('{} unknown vendor {}'.format(
                    avp.description(), avp.vendor))

        comment = '{name}, {datatype}, code {code}'.format(**vars(avp))
        if avp.section != '':
            comment += ', section {}'.format(avp.section)
        self.add_comment(comment)
        self.add('\t{')
        self.add('\t\tstruct dict_avp_data data = {')
        # TODO: remove comments?
        self.add('\t\t\t{},\t/* Code */'.format(avp.code))
        self.add('\t\t\t{},\t/* Vendor */'.format(avp.vendor))
        self.add('\t\t\t\"{}\",\t/* Name */'.format(avp.name))
        self.add('\t\t\t{},\t/* Fixed flags */'.format(
            self.build_flags(', '.join([avp.must, avp.mustnot]))))
        self.add('\t\t\t{},\t/* Fixed flag values */'.format(
            self.build_flags(avp.must)))
        # TODO: add trailing comma?
        self.add('\t\t\tAVP_TYPE_{}\t/* base type of data */'.format(
            DERIVED_TO_BASE.get(avp.datatype, avp.datatype).upper()))
        self.add('\t\t};')
        avp_type = 'NULL'
        if avp.datatype == 'Enumerated':
            self.add('\t\tstruct dict_object\t*type;')
            self.add(
                '\t\tstruct dict_type_data\t tdata = {{ AVP_TYPE_INTEGER32, '
                '"Enumerated({prefix}{name})", NULL, NULL, NULL }};'.format(
                    prefix=vendor_prefix, name=avp.name))
            # XXX: add enumerated values
            self.add('\t\tCHECK_dict_new(DICT_TYPE, &tdata, NULL, &type);')
            avp_type = "type"
        elif avp.datatype in DERIVED_TO_BASE:
            # Other derived types refer to a <datatype>_type variable
            # that write_function() declares at the top of the function.
            avp_type = '{}_type'.format(avp.datatype)
            self.derived(avp.datatype)
        self.add('\t\tCHECK_dict_new(DICT_AVP, &data, {}, NULL);'.format(
            avp_type))
        # TODO: remove ; on scope brace
        self.add('\t};')
        self.add('')

    def comment(self, comment, filename, line_num):
        """Append a blank line, a header rule or a C comment.

        Raises:
            ValueError: if the comment is not '', '=' or ' text'.
        """
        if comment == '':
            self.add('')
        elif comment == '=':
            self.add_header()
        elif comment.startswith(' '):
            self.add_comment(comment[1:])
        else:
            raise ValueError('Unsupported comment "{}"'.format(comment))

    def generate(self):
        """Write the introduction and every function to stdout."""
        fp = sys.stdout
        self.write_introduction(fp)
        for avpfunction in self._functions.values():
            self.write_function(fp, avpfunction)

    def parameter(self, name, value):
        """Ignore parameter rows."""
        pass

    # internal methods

    def current_avpfunction(self):
        """Return current AvpFunction to update.

        Note: allows for easier future enhancement to generate separate
        C functions per AVP groups such as: by csv file, standard, or vendor.
        """
        name = 'add_avps'
        if name not in self._functions:
            self._functions[name] = self.AvpFunction(name)
        return self._functions[name]

    def add(self, line):
        """Append a line to the current function."""
        self.current_avpfunction().lines = line

    def derived(self, value):
        """Record a derived type used by the current function."""
        self.current_avpfunction().derived = value

    def build_c_token(self, value):
        """Convert a string into a valid C token."""
        return re.sub(r'[^\w]', '_', value)

    def build_flags(self, flags):
        """Return the C flag expression for the V and M flags in flags.

        Returns '0' when neither flag is present.
        """
        result = []
        if 'V' in flags:
            result.append('AVP_FLAG_VENDOR')
        if 'M' in flags:
            result.append('AVP_FLAG_MANDATORY')
        if not result:
            return '0'
        return ' |'.join(result)

    def add_comment(self, comment):
        """Append comment as a C comment line."""
        self.add(self.format_comment(comment))

    def add_header(self):
        """Append a header rule line."""
        self.add(self.format_header())

    def format_comment(self, comment):
        """Return comment as a C comment padded to COMMENT_WIDTH."""
        return '\t/* {:<{width}} */'.format(comment, width=self.COMMENT_WIDTH)

    def format_header(self):
        """Return a C comment rule of '=' matching the comment width."""
        return '\t/*={:=<{width}}=*/'.format('', width=self.COMMENT_WIDTH)

    def write_introduction(self, fp):
        """Write the introduction to the generated file."""
        fp.write('''\
/*
Generated by:
\tcsv_to_fd -p {processor} {files}

Do not modify; modify the source .csv files instead.
*/

#include <freeDiameter/extension.h>

#define CHECK_dict_new( _type, _data, _parent, _ref ) \\
\tCHECK_FCT(  fd_dict_new( fd_g_config->cnf_dict, \
(_type), (_data), (_parent), (_ref))  );

#define CHECK_dict_search( _type, _criteria, _what, _result ) \\
\tCHECK_FCT( fd_dict_search( fd_g_config->cnf_dict, \
(_type), (_criteria), (_what), (_result), ENOENT) );
'''.format(
            processor=self.cls_name(),
            files=' '.join(self._filenames)))

    def write_function(self, fp, avpfunction):
        """Generate a function from AvpFunction."""
        function = self.build_c_token(avpfunction.name)
        # Function start
        fp.write('''\

int {}()
{{
'''.format(function))

        # Create variables used by derived type validation
        for derived in avpfunction.derived:
            fp.write('''\
\tstruct dict_object * {name}_type = NULL;
\tCHECK_dict_search(DICT_TYPE, TYPE_BY_NAME, "{name}", &{name}_type);

'''.format(name=derived))

        # Write generated DICT_AVP creation
        fp.write('\n'.join(avpfunction.lines))

        # Write function end
        fp.write('''\

\treturn 0;
}} /* {}() */
'''.format(function))


class JsonProcessor(Processor):
    """Generate freeDiameter JSON object.

    AVP rows are collected in order and printed as one JSON document
    under the "AVPs" key by generate().
    """

    def __init__(self):
        self.avps = []

    def filename(self, filename):
        """Ignore the file name."""

    def avp(self, avp):
        """Append the JSON representation of avp to the collected rows."""
        flag_fields = collections.OrderedDict()
        flag_fields['Must'] = self.build_flags(avp.must)
        flag_fields['MustNot'] = self.build_flags(avp.mustnot)

        entry = collections.OrderedDict()
        entry['Code'] = avp.code
        entry['Flags'] = flag_fields
        entry['Name'] = avp.name
        entry['Type'] = avp.datatype
        entry['Vendor'] = avp.vendor
        self.avps.append(entry)

    def comment(self, comment, filename, line_num):
        """Ignore comment rows."""

    def generate(self):
        """Print the collected AVPs as an indented JSON document."""
        print(json.dumps({"AVPs": self.avps}, indent=2))

    def parameter(self, name, value):
        """Ignore parameter rows."""

    def build_flags(self, flags):
        """Return 'V' and/or 'M', in that order, for those present in flags."""
        return ''.join(letter for letter in 'VM' if letter in flags)


def avp_conflict(description, avp, conflict):
    """Report that avp clashes with the previously seen conflict.

    Always raises ValueError. The message says "duplicated" when the two
    compare equal and "conflicts with" otherwise, and names the file and
    line of the earlier entry. description is the clashing attribute.
    """
    if avp == conflict:
        template = '{} {} duplicated in file "{}" line {}'
        fields = [avp.description(), description]
    else:
        template = '{} {} conflicts with {} in file "{}" line {}'
        fields = [avp.description(), description, conflict.description()]
    fields.extend([conflict.filename, conflict.line_num])
    raise ValueError(template.format(*fields))


def main():
    """Main application entry.

    Parses the command line, feeds every row of each CSV file to the
    selected processor, then asks the processor to generate its output.

    Row errors (TypeError or ValueError) are collected per file; if a
    file produced any, they are written to stderr and the program exits
    with status 1 before the next file is read. Command line errors are
    reported through the option parser's error().
    """

    # Build dict of name: NameProcessor
    processors = {
        cls.cls_name(): cls
        for cls in Processor.__subclasses__()
        }
    processor_names = sorted(processors)

    # Build Processor name to desc
    processor_help = '\n'.join(
        '  {:8} {}'.format(name, processors[name].cls_desc())
        for name in processor_names)

    # Custom OptionParser with improved help
    class MyParser(optparse.OptionParser):
        """Custom OptionParser without epilog formatting."""
        def format_help(self, formatter=None):
            return """\
{}
Supported PROCESSOR options:
{}
""".format(
                optparse.OptionParser.format_help(self, formatter),
                processor_help)

    # Parse options
    parser = MyParser(
        usage='%prog [-h] [-p PROCESSOR] FILE ...',
        description="""\
Convert CSV files FILE ... containing RADIUS or Diameter AVP tables
into various formats using the specified processor PROCESSOR.
""")
    parser.add_option(
        '-p', '--processor',
        default='noop',
        help='AVP processor. One of: {}. [%default]'.format(
             ', '.join(processor_names)))
    (opts, args) = parser.parse_args()
    if not args:
        parser.error('Incorrect number of arguments. Use -h for help.')

    # Find processor. Only the lookup is guarded, so that a KeyError
    # raised inside a processor constructor is not misreported as an
    # unknown processor name.
    if opts.processor not in processors:
        parser.error('Unknown processor "{}".'.format(opts.processor))
    avpproc = processors[opts.processor]()

    # dict of [vendor][code] : Avp
    avp_codes = collections.defaultdict(dict)

    # dict of [vendor][name] : Avp
    avp_names = collections.defaultdict(dict)

    # Process files
    for filename in args:
        avpproc.filename(filename)
        with open(filename, 'r') as csvfile:
            csvdata = csv.DictReader(csvfile, CSV_COLUMN_NAMES,
                                     restkey='extra_cells', restval='')
            # Parameter state; reset at the start of every file.
            standard = ''
            vendor = 0
            errors = []
            for row in csvdata:
                try:
                    if csvdata.restkey in row:
                        raise ValueError('Extra cells: {}'.format(
                            ','.join(row[csvdata.restkey])))
                    if row['name'] in (None, '', 'Attribute Name'):
                        # Blank row or table heading row.
                        continue
                    elif row['name'].startswith('#'):
                        comment = row['name'][1:]
                        avpproc.comment(comment, filename, csvdata.line_num)
                    elif row['name'].startswith('@'):
                        parameter = row['name'][1:]
                        value = row['code']
                        if parameter == 'standard':
                            standard = value
                        elif parameter == 'vendor':
                            vendor = int(value)
                        else:
                            raise ValueError('Unknown parameter "{}"'.format(
                                parameter))
                        avpproc.parameter(parameter, value)
                    else:
                        avp = Avp(filename=filename, line_num=csvdata.line_num,
                                  standard=standard, vendor=vendor,
                                  **row)
                        # Ensure AVP vendor/code not already defined
                        if avp.code in avp_codes[avp.vendor]:
                            conflict = avp_codes[avp.vendor][avp.code]
                            avp_conflict('Code', avp, conflict)
                        avp_codes[avp.vendor][avp.code] = avp
                        # Ensure AVP vendor/name not already defined
                        if avp.name in avp_names[avp.vendor]:
                            conflict = avp_names[avp.vendor][avp.name]
                            avp_conflict('Name', avp, conflict)
                        avp_names[avp.vendor][avp.name] = avp
                        # Process AVP
                        avpproc.avp(avp)
                except (TypeError, ValueError) as err:
                    # Keep going so every bad row in the file is reported.
                    errors.append('CSV file "{}" line {}: {}\n'.format(
                        filename, csvdata.line_num, err))
            if errors:
                sys.stderr.write(''.join(errors))
                sys.exit(1)

    # Generate result
    avpproc.generate()


# Run the converter only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

# vim: set et sw=4 sts=4 :
